from typing import Dict

import torch
from torch import nn
import torch.nn.functional as F

from omegaconf import DictConfig


def train_step_autoencoder(
    cfg: DictConfig,
    model: nn.Module,
    xs: Dict[str, torch.Tensor],
    condition: Dict[str, torch.Tensor],
    idx_data: Dict[str, torch.Tensor],
    geometry: Dict[str, torch.Tensor],
    loss_wrap: nn.Module,
    progress_remaining: float,
):
    """Run the forward pass and loss computation for one autoencoder step.

    Puts ``model`` in training mode, feeds it only ``xs["df"]``, and delegates
    the loss computation to ``loss_wrap``. This function does not call
    ``backward()`` or step an optimizer; presumably the caller does that with
    the returned loss.

    Args:
        cfg: Run configuration. Reads ``cfg.autoencoder.extra_zf_loss``
            (optional; a missing key is treated as False). Reads
            ``cfg.dataset.separate_zf`` only when ``extra_zf_loss`` is truthy.
            Reads ``cfg.training.integral_loss_type`` (optional; default
            ``"mse"``).
        model: Called as ``model(xs["df"], condition=condition)``.
        xs: Input tensors keyed by name. Only ``"df"`` is fed to the model,
            but the whole dict is passed to ``loss_wrap`` alongside the
            predictions.
        condition: Forwarded to the model as the ``condition`` keyword.
        idx_data: Forwarded unchanged to ``loss_wrap``.
        geometry: Forwarded to ``loss_wrap`` as ``geometry``.
        loss_wrap: Callable whose result is unpacked as ``(loss, losses)``.
        progress_remaining: Forwarded to ``loss_wrap``. Presumably the
            fraction of training still to go; confirm against the caller.

    Returns:
        A ``(loss, losses)`` tuple unpacked from the ``loss_wrap`` result.
    """
    model.train()
    # For the autoencoder only the "df" entry is used as model input.
    x_preds = model(xs["df"], condition=condition)

    # Guard the optional flag so that configs without
    # ``autoencoder.extra_zf_loss`` behave as if it were False instead of
    # raising. This matches the lookup in train_step_peft.
    extra_zf_loss = getattr(cfg.autoencoder, "extra_zf_loss", False)
    # TODO(diff) which options?
    separate_zf = cfg.dataset.separate_zf if extra_zf_loss else False

    # TODO(diff) get rid of loss_wrap?
    loss, losses = loss_wrap(
        x_preds,
        xs,  # autoencoder: the inputs themselves are the reference
        idx_data,
        geometry=geometry,
        progress_remaining=progress_remaining,
        separate_zf=separate_zf,
        integral_loss_type=getattr(cfg.training, "integral_loss_type", "mse"),
    )
    return loss, losses


def train_step_peft(
    cfg: DictConfig,
    model: nn.Module,
    xs: Dict[str, torch.Tensor],
    condition: Dict[str, torch.Tensor],
    idx_data: Dict[str, torch.Tensor],
    geometry: Dict[str, torch.Tensor],
    loss_wrap: nn.Module,
    progress_remaining: float,
):
    """Run the forward pass and loss computation for one PEFT training step.

    Switches ``model`` to training mode, runs it on ``xs["df"]``, and hands the
    predictions to ``loss_wrap``. No ``backward()`` or optimizer step happens
    here; presumably the caller handles that.

    Args:
        cfg: Run configuration. ``cfg.autoencoder.extra_zf_loss`` is optional
            and counts as False when absent. ``cfg.dataset.separate_zf`` is
            read only when that flag is truthy.
            ``cfg.training.integral_loss_type`` defaults to ``"mse"``.
        model: Called as ``model(xs["df"], condition=condition)``.
        xs: Input tensors keyed by name. Passed whole to ``loss_wrap``.
        condition: Forwarded to the model as the ``condition`` keyword.
        idx_data: Forwarded unchanged to ``loss_wrap``.
        geometry: Forwarded to ``loss_wrap`` as ``geometry``.
        loss_wrap: Callable whose result is unpacked as ``(loss, losses)``.
        progress_remaining: Forwarded to ``loss_wrap``. Presumably the
            remaining fraction of training; confirm against the caller.

    Returns:
        A ``(loss, losses)`` tuple unpacked from the ``loss_wrap`` result.
    """
    model.train()
    predictions = model(xs["df"], condition=condition)

    # A missing flag is treated exactly like a falsy one.
    zf_flag = getattr(cfg.autoencoder, "extra_zf_loss", False)
    if zf_flag:
        split_zf = cfg.dataset.separate_zf
    else:
        split_zf = False
    loss_kind = getattr(cfg.training, "integral_loss_type", "mse")

    outcome = loss_wrap(
        predictions,
        xs,
        idx_data,
        geometry=geometry,
        progress_remaining=progress_remaining,
        separate_zf=split_zf,
        integral_loss_type=loss_kind,
    )
    total_loss, loss_terms = outcome
    return total_loss, loss_terms
